import torch

from nnet.spex_plus import SpEx_Plus


if __name__ == "__main__":
    spex = SpEx_Plus()
    x = torch.randn(3, 30000)
    aux = torch.randn(3, 7000)
    aux_len = 7000
    y = spex(x, aux, aux_len)
    pass
